import sys, time, os
import numpy as np
import argparse
import warnings
import random
warnings.filterwarnings("ignore")
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
os.environ["CUDA_VISIBLE_DEVICES"]="2" # set the device as here.

import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import Dataset as HFDataset

from load_data import *
from utils import *

# Command-line interface for the fine-tuning run. The parsed values are
# read through the module-level ``args`` namespace by the code below.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    type=str,
    default='cora',
    choices=['cora', 'pubmed', 'ogbn-arxiv', 'ogbn-products'],
)
parser.add_argument('--max_length', type=int, default=512)
parser.add_argument('--epochs', type=float, default=1000)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--wd', type=float, default=0)
parser.add_argument(
    '--input_mode',
    type=str,
    default='t_t_tape',
    # Every combination of three underscore-separated fields, enumerated
    # with the first field varying fastest.
    choices=[
        f'{first}_{second}_{third}'
        for second in ('x', 't', 'tl')
        for third in ('x', 'tape')
        for first in ('t', 'c', 'tc', 'ct')
    ],
)
parser.add_argument(
    '--input_label_pool',
    type=int,
    default=2,
    choices=[0, 1, 2, 3],
    help='Number of labels to choose from the prediction of teacher GNNs',
)
parser.add_argument(
    '--output_mode',
    type=str,
    default='g_x',
    choices=['g_x', 'd_x', 'p_x', 'g_r', 'p_r', 'g_2l'],
)
parser.add_argument('--num_neighbors', type=int, default=5)
parser.add_argument(
    '--model_id',
    type=str,
    default='small',
    choices=['small', 'base', 'large'],
)
parser.add_argument(
    '--patience',
    type=int,
    default=3,
    help='Early stopping patience',
)

args = parser.parse_args()

def generate(tokenizer, model, valid_dataloader):
    """Greedy-decode every batch of prompts and return the decoded strings.

    Args:
        tokenizer: Tokenizer used both to encode the prompts and to decode
            the generated ids.
        model: Seq2seq model exposing ``generate``; it is switched to eval
            mode and left in that mode on return.
        valid_dataloader: Iterable yielding batches that the tokenizer can
            encode directly (presumably lists of prompt strings -- verify
            against the caller).

    Returns:
        A flat list with one decoded string per input example, in
        dataloader order.
    """
    model.eval()
    # Follow the model's own device instead of assuming 'cuda', so the
    # function also works when the model lives on CPU or another device.
    device = next(model.parameters()).device
    total_output = []
    with torch.no_grad():
        for batch_input in tqdm(valid_dataloader):
            inputs = tokenizer(batch_input,
                               padding='max_length',
                               truncation=True,
                               max_length=args.max_length,
                               return_tensors='pt').to(device)
            # No length arguments are passed, so the output length is
            # governed by the model's default generation settings.
            output_sequences = model.generate(**inputs, do_sample=False)
            total_output.extend(
                tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
            )
    return total_output

def evaluate_acc(groundtruth, output, label_set):
    """Score generated texts against the ground-truth labels.

    All strings are lower-cased before comparison. For each generated text
    the labels found by ``matching_order`` are taken and the first one is
    used as the prediction; a text that yields no label is predicted as
    ``'None'`` and counted as nonsense.

    Args:
        groundtruth: Iterable of true label strings.
        output: Iterable of generated strings, aligned with ``groundtruth``.
        label_set: Iterable of all candidate label strings.

    Returns:
        ``(acc, nonsense)`` where ``acc`` is whatever ``accuracy`` returns
        for the predictions and ``nonsense`` is the fraction of outputs in
        which no label was found.
    """
    truth_lc = [g.lower() for g in groundtruth]
    generated_lc = [o.lower() for o in output]
    labels_lc = [lab.lower() for lab in label_set]

    # in the '2l' output modes the pool will be repeated first, and then the true label
    rev = '2l' in args.output_mode

    predictions = []
    unmatched = 0
    for text in generated_lc:
        # sort the appearance of the labels in the output
        ranked = matching_order(labels_lc, text, rev=rev)
        if len(ranked) >= 1:
            predictions.append(ranked[0])
        else:
            unmatched += 1
            predictions.append('None')

    acc = accuracy(predictions, truth_lc)
    nonsense = unmatched / len(generated_lc)
    return acc, nonsense

# Meta infos
# Translate the CLI size flag into the Hugging Face checkpoint name.
model_id = {
    'small': "google/flan-t5-small",
    'base': "google/flan-t5-base",
    'large': "google/flan-t5-large",
}[args.model_id]

# Optimisation hyper-parameters.
epochs = args.epochs
batch_size = args.batch_size
lr = args.lr
wd = args.wd

# Data locations and prompt configuration.
dataset = args.dataset
dataset_folder = "processed_data"
split_folder = f"raw_data/{args.dataset}/splits"
input_mode = args.input_mode
output_mode = args.output_mode
num_neighbors = args.num_neighbors
input_label_pool = args.input_label_pool

print(f"Model: {model_id}, Dataset: {dataset}, Input mode: {input_mode}, Output mode: {output_mode}, Input label pool: {input_label_pool}")
print(f"Epochs: {epochs}, Batch size: {batch_size}, Learning rate: {lr}, Num neighbors: {num_neighbors}\n")
print(f"Patience: {args.patience}, weight decay: {args.wd}")

# Prefix the input mode with the label-pool size, e.g. '2l_t_t_tape'
# (a pool of 0 gives the '0l_' prefix).
input_mode = str(input_label_pool) + 'l_' + input_mode

# Setup models
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The products dataset has its own template module; every other dataset
# uses the citation templates.
if dataset != 'ogbn-products':
    from templates.citation_templates import get_template
else:
    from templates.products_templates import get_template

# 'c' is for classification task
prompt_function, input_function, output_function = get_template('c', input_mode, output_mode)

(title_list,
 content_list,
 label_list,
 neighbors_list,
 rationale_list,
 gpt_list) = load_meta_data_lists(dataset_folder, dataset, input_mode, output_mode)
label_set = set(label_list)

# Keep at most `num_neighbors` entries per node, in their stored order.
neighbors_list = [neighbors[:num_neighbors] for neighbors in neighbors_list]

def tokenize_function(examples):
    """Tokenize a batch of ``input``/``target`` pairs for seq2seq training.

    Intended for ``Dataset.map(..., batched=True)``: ``examples['input']``
    and ``examples['target']`` are expected to be lists of strings.

    Both sides are padded/truncated to ``args.max_length``. Padding
    positions in the labels are replaced with ``-100`` so that they are
    ignored by the loss; otherwise the model would also be trained (and
    validated) on predicting pad tokens.

    Returns:
        The tokenizer output for the inputs, with an added ``"labels"``
        entry holding the masked target ids.
    """
    model_inputs = tokenizer(examples['input'], padding='max_length', truncation=True, max_length=args.max_length)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['target'], padding='max_length', truncation=True, max_length=args.max_length)

    # -100 is the index ignored by the cross-entropy loss in Transformers
    # seq2seq models.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs

# Run the fine-tune / test cycle once per seed and report mean and std of
# the test accuracy.
if dataset == 'cora' or dataset == 'pubmed':
    seeds = list(range(5))
else:
    seeds = [0,0,0,0,0] # for OGB dataset, no seed used for split, so testing them on the given split for 5 times

accs = []
for seed in seeds:
    print("\nSeed: ", seed)

    # Reload the pretrained checkpoint so every run starts from the same weights.
    # model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16,)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id,)

    total_parameters = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_parameters}")

    # Each dataset family has its own split loader.
    if dataset == 'cora' or dataset == 'pubmed':
        train_idx, valid_idx, test_idx = load_split_idx(split_folder, dataset, seed)
    elif dataset == 'ogbn-arxiv':
        train_idx, valid_idx, test_idx = load_split_idx_ogb(split_folder, dataset)
    elif dataset == 'ogbn-products':
        train_idx, valid_idx, test_idx = load_split_idx_tape(split_folder, dataset)


    print("Generating input and output lists...")

    label_and_prob_list, raw_label_and_prob_list = load_label_and_prob_list(dataset_folder, dataset, seed)
    lists = (label_set, title_list, content_list, label_list, label_and_prob_list, neighbors_list, rationale_list, gpt_list, raw_label_and_prob_list)

    # Build one text-to-text dataset per split (train, valid, test).
    all_data = []
    for indices in [train_idx, valid_idx, test_idx]:
        input_list, output_list = generate_idx_specific_input_output_list(
            lists=lists,
            functions=(prompt_function, input_function, output_function),
            indices=indices)
        all_data.append(HFDataset.from_dict({"input": input_list, "target": output_list}))
    # input_list / output_list still hold the last split built above (test).
    print()
    print("Template Example: ")
    print(input_list[0])
    print(output_list[0])
    print()
    train_set, valid_set, test_set = all_data[0], all_data[1], all_data[2]
    tokenized_train_set = train_set.map(tokenize_function, batched=True)
    tokenized_valid_set = valid_set.map(tokenize_function, batched=True)
    tokenized_test_set = test_set.map(tokenize_function, batched=True)

    output_dir = f'finetuned_models/{model_id}_{dataset}_{seed}_{input_mode}_{output_mode}_{num_neighbors}_{wd}'
    training_args = TrainingArguments(
        output_dir=output_dir,          # Output directory
        num_train_epochs=epochs,              # Total number of training epochs
        per_device_train_batch_size=batch_size,   # Batch size per device during training
        per_device_eval_batch_size=batch_size,    # Batch size for evaluation
        warmup_steps=100,                # Number of warmup steps for learning rate scheduler
        weight_decay=args.wd,               # Strength of weight decay
        gradient_accumulation_steps=1,
        logging_dir='./logs',            # Directory for storing logs
        logging_steps=10,                # Training metrics are logged every 10 steps
        learning_rate=lr,              # Learning rate
        evaluation_strategy="steps",     # Evaluate on a step schedule (see eval_steps)
        eval_steps=500,                   # Run validation every 500 steps
        # bf16=True,                       # Use mixed precision while training
        load_best_model_at_end=True,     # The best model is loaded at the end of training
        save_strategy="steps",           # Must match the evaluation strategy for load_best_model_at_end
        save_total_limit=1,
        disable_tqdm=True,
    )

    print(training_args.device)

    trainer = Trainer(
        model=model,                     # The instantiated 🤗 Transformers model to be trained
        args=training_args,              # Training arguments
        train_dataset=tokenized_train_set,
        eval_dataset=tokenized_valid_set,
        # Honour the --patience flag (it was previously hard-coded to 3).
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]  # Early stopping
    )

    trainer.train()

    # Test by free-form generation on the raw prompts, then match the
    # generated text against the label set.
    trained_model = trainer.model
    test_input_dataloader = DataLoader(test_set['input'], batch_size=batch_size)
    generated_text = generate(tokenizer, trained_model.to('cuda'), test_input_dataloader)
    test_acc, test_nonsense = evaluate_acc([label_list[i] for i in test_idx], generated_text, label_set)
    print(f"Test acc: {test_acc}, nonsense: {test_nonsense}")
    accs.append(test_acc)
print(f"Average test acc: {np.mean(accs)}, std: {np.std(accs)}")